{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matchzoo as mz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_pack = mz.datasets.wiki_qa.load_data('train', task='ranking')\n",
    "valid_pack = mz.datasets.wiki_qa.load_data('dev', task='ranking', filter=True)\n",
    "predict_pack = mz.datasets.wiki_qa.load_data('test', task='ranking', filter=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 8715.09it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 18841/18841 [00:04<00:00, 4671.09it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 753780.90it/s]\n",
      "Building FrequencyFilterUnit from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 124006.70it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 78196.55it/s] \n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 303171.66it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 672512.12it/s]\n",
      "Building VocabularyUnit from a datapack.: 100%|██████████| 404415/404415 [00:00<00:00, 2604778.31it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 8568.80it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 18841/18841 [00:04<00:00, 4695.30it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 110666.71it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 158182.62it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 114733.17it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 395844.21it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 602682.10it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 112560.96it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 88610.30it/s]\n"
     ]
    }
   ],
   "source": [
    "preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=10, fixed_length_right=40, remove_stop_words=False)\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 122/122 [00:00<00:00, 7976.57it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 1115/1115 [00:00<00:00, 4836.02it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 121345.33it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 101528.79it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 113992.32it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 82840.39it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 504819.62it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 64349.23it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 76349.71it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 237/237 [00:00<00:00, 9011.99it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2300/2300 [00:00<00:00, 4557.39it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 123811.53it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 87136.22it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 109487.00it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 171122.40it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 556209.59it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 102850.50it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 16167.00it/s]\n"
     ]
    }
   ],
   "source": [
    "valid_pack_processed = preprocessor.transform(valid_pack)\n",
    "predict_pack_processed = preprocessor.transform(predict_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "ranking_task = mz.tasks.Ranking(loss=mz.losses.RankHingeLoss())\n",
    "ranking_task.metrics = [\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=3),\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=5),\n",
    "    mz.metrics.MeanAveragePrecision()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter \"name\" set to MatchLSTM.\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 10)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "text_right (InputLayer)         (None, 40)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding (Embedding)           multiple             1667400     text_left[0][0]                  \n",
      "                                                                 text_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "lstm_left (LSTM)                (None, 10, 100)      80400       embedding[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "lstm_right (LSTM)               (None, 40, 100)      80400       embedding[1][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "lambda_1 (Lambda)               (None, 10, 100)      0           lstm_left[0][0]                  \n",
      "                                                                 lstm_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_1 (Concatenate)     (None, 50, 100)      0           lambda_1[0][0]                   \n",
      "                                                                 lstm_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "lstm_merge (LSTM)               (None, 200)          240800      concatenate_1[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (None, 200)          0           lstm_merge[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dense_5 (Dense)                 (None, 100)          20100       dropout_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_6 (Dense)                 (None, 1)            101         dense_5[0][0]                    \n",
      "==================================================================================================\n",
      "Total params: 2,089,201\n",
      "Trainable params: 2,089,201\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = mz.contrib.models.MatchLSTM()\n",
    "model.params['input_shapes'] = preprocessor.context['input_shapes']\n",
    "model.params['task'] = ranking_task\n",
    "model.params['embedding_input_dim'] = preprocessor.context['vocab_size']\n",
    "model.params['embedding_output_dim'] = 100\n",
    "model.params['embedding_trainable'] = True\n",
    "model.params['fc_num_units'] = 100\n",
    "model.params['lstm_num_units'] = 100\n",
    "model.params['dropout_rate'] = 0.5\n",
    "model.params['optimizer'] = 'adadelta'\n",
    "model.guess_and_fill_missing_params()\n",
    "model.build()\n",
    "model.compile()\n",
    "model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "glove_embedding = mz.datasets.embeddings.load_glove_embedding(dimension=100)\n",
    "embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load_embedding_matrix(embedding_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_x, pred_y = predict_pack_processed[:].unpack()\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "102"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_generator = mz.PairDataGenerator(train_pack_processed, num_dup=2, num_neg=1, batch_size=20)\n",
    "len(train_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/30\n",
      "102/102 [==============================] - 21s 201ms/step - loss: 0.8767\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5099543712381165 - normalized_discounted_cumulative_gain@5(0): 0.570930778113466 - mean_average_precision(0): 0.5375570266858428\n",
      "Epoch 2/30\n",
      "102/102 [==============================] - 18s 181ms/step - loss: 0.7819\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5837647495178753 - normalized_discounted_cumulative_gain@5(0): 0.6394638886458646 - mean_average_precision(0): 0.6040718458067675\n",
      "Epoch 3/30\n",
      "102/102 [==============================] - 18s 181ms/step - loss: 0.6576\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5749813227815257 - normalized_discounted_cumulative_gain@5(0): 0.6368413359562173 - mean_average_precision(0): 0.593478439599065\n",
      "Epoch 4/30\n",
      "102/102 [==============================] - 18s 180ms/step - loss: 0.5421\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5863836869866456 - normalized_discounted_cumulative_gain@5(0): 0.6591282491862374 - mean_average_precision(0): 0.6091608900944344\n",
      "Epoch 5/30\n",
      "102/102 [==============================] - 18s 179ms/step - loss: 0.4883\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6099606276322094 - normalized_discounted_cumulative_gain@5(0): 0.6719634966454419 - mean_average_precision(0): 0.6277677491285086\n",
      "Epoch 6/30\n",
      "102/102 [==============================] - 18s 180ms/step - loss: 0.3939\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5965922615581667 - normalized_discounted_cumulative_gain@5(0): 0.6642886302630996 - mean_average_precision(0): 0.6175022457273016\n",
      "Epoch 7/30\n",
      "102/102 [==============================] - 18s 179ms/step - loss: 0.3494\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5969163327844941 - normalized_discounted_cumulative_gain@5(0): 0.654980490270505 - mean_average_precision(0): 0.6048968512389957\n",
      "Epoch 8/30\n",
      "102/102 [==============================] - 18s 179ms/step - loss: 0.2574\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5882465905736889 - normalized_discounted_cumulative_gain@5(0): 0.648998229758532 - mean_average_precision(0): 0.6056487050554731\n",
      "Epoch 9/30\n",
      "102/102 [==============================] - 18s 180ms/step - loss: 0.2099\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5812732137174904 - normalized_discounted_cumulative_gain@5(0): 0.6385733737993456 - mean_average_precision(0): 0.5904282837060495\n",
      "Epoch 10/30\n",
      "102/102 [==============================] - 18s 176ms/step - loss: 0.1546\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5751827504118617 - normalized_discounted_cumulative_gain@5(0): 0.6381139502232698 - mean_average_precision(0): 0.5875240500240501\n",
      "Epoch 11/30\n",
      "102/102 [==============================] - 18s 176ms/step - loss: 0.1145\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5792371607274251 - normalized_discounted_cumulative_gain@5(0): 0.6314012936935991 - mean_average_precision(0): 0.5839661712308527\n",
      "Epoch 12/30\n",
      "102/102 [==============================] - 18s 177ms/step - loss: 0.0797\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5806102491805909 - normalized_discounted_cumulative_gain@5(0): 0.6371007803777863 - mean_average_precision(0): 0.5909281842919222\n",
      "Epoch 13/30\n",
      "102/102 [==============================] - 18s 178ms/step - loss: 0.0590\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5796805819552002 - normalized_discounted_cumulative_gain@5(0): 0.6389773331883829 - mean_average_precision(0): 0.5917544806182773\n",
      "Epoch 14/30\n",
      "102/102 [==============================] - 18s 177ms/step - loss: 0.0450\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5753008332410817 - normalized_discounted_cumulative_gain@5(0): 0.6497273394999601 - mean_average_precision(0): 0.597743832335477\n",
      "Epoch 15/30\n",
      "102/102 [==============================] - 18s 176ms/step - loss: 0.0293\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5723965263215952 - normalized_discounted_cumulative_gain@5(0): 0.6391558300930712 - mean_average_precision(0): 0.592764774489458\n",
      "Epoch 16/30\n",
      "102/102 [==============================] - 18s 177ms/step - loss: 0.0247\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5676736627660003 - normalized_discounted_cumulative_gain@5(0): 0.633191440835839 - mean_average_precision(0): 0.5847983094027398\n",
      "Epoch 17/30\n",
      "102/102 [==============================] - 18s 176ms/step - loss: 0.0216\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5525329426472332 - normalized_discounted_cumulative_gain@5(0): 0.6309600048612133 - mean_average_precision(0): 0.5766039731230829\n",
      "Epoch 18/30\n",
      "102/102 [==============================] - 18s 176ms/step - loss: 0.0081\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5744551426785978 - normalized_discounted_cumulative_gain@5(0): 0.6287229218645893 - mean_average_precision(0): 0.5812227153999306\n",
      "Epoch 19/30\n",
      "102/102 [==============================] - 19s 189ms/step - loss: 0.0081\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5722612989112447 - normalized_discounted_cumulative_gain@5(0): 0.6295674189994438 - mean_average_precision(0): 0.5801665985210289\n",
      "Epoch 20/30\n",
      "102/102 [==============================] - 30s 290ms/step - loss: 0.0071\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5720078118989034 - normalized_discounted_cumulative_gain@5(0): 0.629474017975712 - mean_average_precision(0): 0.5818016166908571\n",
      "Epoch 21/30\n",
      "102/102 [==============================] - 35s 344ms/step - loss: 0.0032\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5779640888091648 - normalized_discounted_cumulative_gain@5(0): 0.6402824260197859 - mean_average_precision(0): 0.5865265536151612\n",
      "Epoch 22/30\n",
      "102/102 [==============================] - 34s 332ms/step - loss: 0.0036\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5951454083750933 - normalized_discounted_cumulative_gain@5(0): 0.6452503109234422 - mean_average_precision(0): 0.5949222881789895\n",
      "Epoch 23/30\n",
      "102/102 [==============================] - 34s 329ms/step - loss: 0.0079\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5674509389409547 - normalized_discounted_cumulative_gain@5(0): 0.6299302092242881 - mean_average_precision(0): 0.578655212949017\n",
      "Epoch 24/30\n",
      "102/102 [==============================] - 26s 251ms/step - loss: 0.0039\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5862129731591318 - normalized_discounted_cumulative_gain@5(0): 0.6324312545659629 - mean_average_precision(0): 0.5876746473848405\n",
      "Epoch 25/30\n",
      "102/102 [==============================] - 20s 197ms/step - loss: 0.0098\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5680726528158403 - normalized_discounted_cumulative_gain@5(0): 0.6277027487171363 - mean_average_precision(0): 0.5807085583193178\n",
      "Epoch 26/30\n",
      "102/102 [==============================] - 20s 200ms/step - loss: 0.0024\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5782369586725221 - normalized_discounted_cumulative_gain@5(0): 0.634863735624802 - mean_average_precision(0): 0.5873109795391998\n",
      "Epoch 27/30\n",
      "102/102 [==============================] - 21s 202ms/step - loss: 0.0055\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5812718334666404 - normalized_discounted_cumulative_gain@5(0): 0.633128110764345 - mean_average_precision(0): 0.5870939592855374\n",
      "Epoch 28/30\n",
      "102/102 [==============================] - 20s 195ms/step - loss: 0.0027\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5653059740396857 - normalized_discounted_cumulative_gain@5(0): 0.6233327598011974 - mean_average_precision(0): 0.575225047972628\n",
      "Epoch 29/30\n",
      "102/102 [==============================] - 20s 200ms/step - loss: 0.0063\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5835418775394088 - normalized_discounted_cumulative_gain@5(0): 0.6341004674725141 - mean_average_precision(0): 0.5880507499198618\n",
      "Epoch 30/30\n",
      "102/102 [==============================] - 20s 194ms/step - loss: 0.0043\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5598769325570103 - normalized_discounted_cumulative_gain@5(0): 0.6206733618765706 - mean_average_precision(0): 0.57247561145998\n"
     ]
    }
   ],
   "source": [
    "history = model.fit_generator(train_generator, epochs=30, callbacks=[evaluate], workers=5, use_multiprocessing=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
